"""
FEniCS tutorial demo program: Poisson equation with Dirichlet conditions.
Test problem is chosen to give an exact solution at all nodes of the mesh.

  -Laplace(u) = f    in the unit square
            u = u_D  on the boundary

  u_D = 1 + x^2 + 2y^2
    f = -6


ref: https://www.iesensor.com/blog/2017/05/24/gmsh_fenics_meshing/
     https://comphysblog.wordpress.com/2018/08/15/fenics-2d-electrostatics-with-imported-mesh-and-boundaries/
"""

from __future__ import print_function
from fenics import *
from dolfin import *
import matplotlib.pyplot as plt
import matplotlib as mpl
import os
#from vtkplotter.dolfin import plot
import numpy as np
import math

class PlasmaHModel:
    """Steady coupled density / velocity solve on an imported 2-D mesh.

    The unknown is a mixed function ``w = (n, v)`` with ``n`` in P1 and ``v``
    in P2.  The residual assembled in ``_steady_residual`` contains

    * for ``n``: the divergence of ``n * v * b``, diffusion with coefficient
      ``Dr`` acting on the part of the gradient orthogonal to ``b``, and a
      least-squares (GLS) stabilisation term scaled by the squared cell size;
    * for ``v``: diffusion with coefficient ``niup`` along ``b`` and
      ``niur`` orthogonal to ``b``.

    ``b`` is the unit vector of the field returned by ``efit.get_B_rz`` at
    every P2 degree of freedom.
    """

    # Facet markers used for the Dirichlet conditions; they must match the
    # ids stored in "<fname>_facet_region.xml".
    _TAG_INNER_EDGE = 384
    _TAG_DIVERTOR_LL = 386
    _TAG_DIVERTOR_LR = 387

    def __init__(self, efit):
        """Store the object used to evaluate the magnetic field.

        efit -- object exposing ``get_B_rz(point)``.  ``run`` calls it with a
            length-2 NumPy array holding the coordinates of one degree of
            freedom and reads entries ``[0]`` and ``[1]`` of the result.
        """
        self.efit = efit

    def run(self, fname="hl2a"):
        """Load the mesh, solve the nonlinear problem and plot the result.

        fname -- base name of the mesh files.  ``<fname>.xml`` holds the mesh
            and ``<fname>_facet_region.xml`` the facet markers.  Defaults to
            ``"hl2a"``, the value that used to be hard-coded.

        Side effects: prints progress information, writes ``nv.pdf`` in the
        current directory and opens a matplotlib window.

        Raises IOError when the facet marker file is missing (the boundary
        conditions cannot be built without it).
        """
        mesh, boundaries = self._load_mesh(fname)
        print("xml mesh reading done")

        # Mixed space: P1 for the density n, P2 for the velocity v.
        density_element = FiniteElement('P', mesh.ufl_cell(), 1)
        velocity_element = FiniteElement('P', mesh.ufl_cell(), 2)
        W = FunctionSpace(mesh, MixedElement([density_element, velocity_element]))

        b = self._field_direction(mesh)
        bcs = self._dirichlet_conditions(W, boundaries)

        w = Function(W)
        F = self._steady_residual(mesh, W, w, b)
        J = derivative(F, w)

        # NOTE(review): there is no matching end() call; the original had it
        # commented out, so it is left that way here.
        begin("Solving ....")
        solve(F == 0, w, bcs, J=J)

        # Extract the two solution components for plotting.
        (n, v) = w.split()
        self._plot_solution(n, v, b)

    @staticmethod
    def _load_mesh(fname):
        """Read the mesh and its facet markers from the ``fname`` XML files."""
        facet_path = fname + "_facet_region.xml"
        # Checked before the mesh is read so that a missing marker file fails
        # immediately with a clear message instead of an unbound-variable
        # error once the boundary conditions are built.
        if not os.path.exists(facet_path):
            raise IOError(
                "facet region file not found: {0}".format(facet_path))
        mesh = Mesh(fname + ".xml")
        boundaries = MeshFunction("size_t", mesh, facet_path)
        return mesh, boundaries

    @staticmethod
    def _unit_direction(b0, b1):
        """Return ``(b0, b1)`` scaled to unit length, or zeros for a null field."""
        magnitude = math.hypot(b0, b1)
        if magnitude == 0.0:
            # No direction is defined where the field vanishes; a zero vector
            # keeps NaN/inf out of the coefficient instead of dividing by 0.
            return 0.0, 0.0
        return b0 / magnitude, b1 / magnitude

    def _field_direction(self, mesh):
        """Build the P2 vector function holding the unit field direction b."""
        B = FunctionSpace(mesh, 'P', 2)
        BB = VectorFunctionSpace(mesh, "P", 2)

        # Coordinates of the scalar P2 dofs as an (ndofs, gdim) array.
        gdim = 2
        dofs_x = B.tabulate_dof_coordinates().reshape((-1, gdim))
        print("shape of dofs_x: ", dofs_x.shape)

        b0 = Function(B)
        b1 = Function(B)
        b0_array = b0.vector().get_local()
        b1_array = b1.vector().get_local()
        print("shape of b0_array: ", b0_array.shape)

        # Evaluate the field once per dof and store its normalised components.
        for i, (x0, x1) in enumerate(dofs_x[:b0_array.shape[0]]):
            field = self.efit.get_B_rz(np.array([x0, x1]))
            b0_array[i], b1_array[i] = self._unit_direction(field[0], field[1])

        b0.vector().set_local(b0_array)
        b1.vector().set_local(b1_array)

        b = Function(BB)
        assign(b, [b0, b1])
        return b

    def _dirichlet_conditions(self, W, boundaries):
        """Return the Dirichlet conditions on n (W.sub(0)) and v (W.sub(1))."""
        return [
            # density fixed on the inner edge
            DirichletBC(W.sub(0), Constant(1.0e18), boundaries,
                        self._TAG_INNER_EDGE),
            # velocity: zero on the inner edge, opposite signs on the two
            # divertor boundaries
            DirichletBC(W.sub(1), Constant(0.0), boundaries,
                        self._TAG_INNER_EDGE),
            DirichletBC(W.sub(1), Constant(1.0e4), boundaries,
                        self._TAG_DIVERTOR_LL),
            DirichletBC(W.sub(1), Constant(-1.0e4), boundaries,
                        self._TAG_DIVERTOR_LR),
        ]

    @staticmethod
    def _steady_residual(mesh, W, w, b):
        """Assemble the time-independent residual form F(w; n_, v_)."""
        Dr = 1.0        # diffusion of n orthogonal to b
        niup = 1000.0   # diffusion of v along b
        niur = 1.0      # diffusion of v orthogonal to b
        C = 5.0         # GLS stabilisation factor

        (n_, v_) = TestFunctions(W)
        (n, v) = split(w)

        # Squared cell size (twice the circumradius) scaling the GLS term.
        h = 2.0 * Circumradius(mesh)
        hh = inner(h, h)

        # Density equation: advection along b, diffusion orthogonal to b
        # (full gradient minus its component along b), GLS stabilisation.
        F_n = div(n * v * b) * n_ * dx
        F_n = F_n + (Dr * dot(nabla_grad(n), nabla_grad(n_)) * dx
                     - Dr * dot(b * dot(b, nabla_grad(n)), nabla_grad(n_)) * dx)
        F_n = F_n + C * hh * div(n * v * b) * div(n_ * v * b) * dx

        # Velocity equation: anisotropic diffusion only.  Earlier flux-form
        # and Burgers-type variants of this equation are not active.
        F_v = niup * dot(b * dot(b, nabla_grad(v)), nabla_grad(v_)) * dx
        F_v = F_v + (niur * dot(nabla_grad(v), nabla_grad(v_)) * dx
                     - niur * dot(b * dot(b, nabla_grad(v)), nabla_grad(v_)) * dx)

        return F_n + F_v

    @staticmethod
    def _plot_solution(n, v, b):
        """Plot n, v and b side by side, save to nv.pdf and show the figure."""
        fig = plt.figure()

        plt.subplot(1, 3, 1)
        handle = plot(n, mode='color', cmap=mpl.cm.Spectral_r)
        plt.colorbar(handle)

        plt.subplot(1, 3, 2)
        handle = plot(v, mode='color', cmap=mpl.cm.Spectral_r)
        plt.colorbar(handle)

        plt.subplot(1, 3, 3)
        plot(b, scale_units='xy', scale=1.0e10)

        fig.savefig("nv.pdf")
        plt.show()